# Загружаем pytorch для работы с нейронными сетями
import torch
import torch.nn as nn
import torch.nn.functional as F
# Для работы с изображениями/графиками
from torchvision import transforms
# Загружаем способы интерполяции изображений
from torchvision.transforms.functional import InterpolationMode as IM
import matplotlib.pyplot as plt
# Для логирования метрик и функций потерь в ходе обучения
from torch.utils.tensorboard import SummaryWriter
# Для удобной работы с обучающей/тестовой выборкой
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
# Прочее
import numpy as np
from tqdm.auto import tqdm
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm
Для начала работы с данными требуется выполнить следующие пункты:
Определиться со способом хранения/чтения данных с диска. В задачах комьютерного зрения датасеты, как правило, имеют большой размер, который не помещается в оперативную память. Поэтому предлагается несколько форматов хранения: HDF5, memory-mapped files и "сырой" вид, т.е хранение .jpg/.png файлов на диске. Классы с указанными способами уже описаны в файле utils.py. Вам предлагается лишь замерить скорость чтения данных для каждого из форматов, затем выбрать наиболее быстрый (чуть позже).
hdf5 позволяет разбивать массивы информации на chunks, которые организованы в виде B-деревьев. Это имеет смысл при чтении hyperslabs - многомерных срезов массива, которые несмежны в памяти (non-contiguous). По умолчанию, hdf5 хранит данные непрерывно (contiguous)Memory-mapping файлов в оперативную память позволяет пропустить этап буфферизации, тем самым пропуская операцию копирования, лениво загружая информацию напрямую. Особенность этого подхода в том, что алгоритмически Best case скорости чтения достигается на непрерывном блоке информации (contiguous), а Worst case - наоборот, на несмежном в памяти (non-contiguous) блоке (на порядки хуже, чем потенциально возможно в hdf5).Привести все пары (изображение, маска) к единому размеру target_shape, указанному далее в словаре конфигурации default_config. Предлагается следующая последовательность действий:
transforms.Resize: при целочисленном аргументе size наименьшая сторона входного изображения будет интерполирована до size, а другая сторона (наибольшая) до размера size * aspect_ratio, т.е сохраняя соотношение сторон aspect_ratiotarget_shape. Возможны два случая: оставшаяся сторона меньше или больше (случай с равенством можно свести к ситуации "меньше на 0") требуемого размера. В первом случае будем дополнять изображение пикселями со значением pad_value при помощи transforms.Pad, а во втором - обрезать изображение при помощи transforms.CenterCrop.Последовательное исполнение операций модуля
transformsможно выполнить при помощи transforms.Compose.
- Ответить на вопрос:
А зачем, вообще, требуется сводить все изображения к одному размеру?
Ваш ответ: чтобы стакать их в батчи и хранить батчи как тензор [N, C, W, H]
from utils import *
def resize(img: Image, target_shape: tuple[int, int], pad_val: int) -> np.array:
"""
Приводит входное изображение `img` к размеру `target_shape`, указанной выше
последовательностью действий. Предполагается, что требуемый размер `target_shape` "квадратный"
"""
# Проверяем равенство желаемых размеров сторон изображения
assert target_shape[0] == target_shape[1]
# Необходимо для "универсальности" определения количества недостающих padding-пикселей.
# В случае если размерность текущего изображения меньше требуемой - имеем неотрицательное
# Количество недостающих padding-пикселей, relu возвращает это число без изменений.
# В случае если размерность текущего изображения больше требуемой - имеем отрицательное
# Количество недостающих padding-пикселей, т.е в padding пикселях мы не нуждаемся и
# relu возвращает значение 0.
def relu(x):
return x * (x > 0)
# Масштабируем наименьшую размерность `img` под `target_shape`
# В качестве способа интерполяции выберем интерполяцию методом ближайшего соседа
# Это необходимо для сохранения множества значений маски сегментации
img = transforms.Resize(target_shape[0], interpolation=IM.NEAREST)(img)
# Вычисляем количество недостающих padding-пикселей для каждой из сторон изображения
h, w = img.size[0], img.size[1]
h_diff, w_diff = target_shape[0] - h, target_shape[1] - w
pad_left, pad_right = w_diff // 2, w_diff - w_diff // 2
pad_top, pad_bottom = h_diff // 2, h_diff - h_diff // 2
resize_transform = transforms.Compose([
# Добавляем padding-пиксели. Если их нет, то операция Pad ничего не изменит (случай "больше").
transforms.Pad(
(relu(pad_left), relu(pad_top), relu(pad_right), relu(pad_bottom)),
fill=pad_val, padding_mode="constant"
),
# Обрезаем "лишние" пиксели. Если их нет, то CenterCrop ничего не изменит (случай "меньше").
transforms.CenterCrop(target_shape),
# Преобразуем PIL.Image изображение в массив np.array
transforms.Lambda(lambda x: np.array(x))
])
return resize_transform(img)
def prepare_dataset(config: dict, storage_class: Type[storage_class]):
"""
Предобрабатывает датасет и эффективно его сохраняет на диск
"""
with open(config["annotation_file"]) as f:
lines = f.readlines()
# Заводим массивы для блоков изображений, помещаемых в память
input_chunk = np.empty((config["chunk_size"], *config["target_shape"], 3), dtype=np.uint8)
target_chunk = np.empty((config["chunk_size"], *config["target_shape"]), dtype=np.uint8)
# Делим датасет на блоки
config["dataset_size"] = len(lines)
num_chunks = config["dataset_size"] // config["chunk_size"] + bool(config["dataset_size"] % config["chunk_size"])
dataset = storage_class(config)
# Читаем изображения с диска, предобрабатываем и сохраняем в выбранный нами формат
for chunk_idx in tqdm(range(num_chunks)):
for pos in range(config["chunk_size"]):
flat_idx = chunk_idx * config["chunk_size"] + pos
if (flat_idx >= config["dataset_size"]):
break
img_name, label = lines[flat_idx].rstrip("\n").split(' ')
input_raw = Image.open(os.path.join(config["input_dir"], img_name + ".jpg")).convert("RGB")
target_raw = Image.open(os.path.join(config["target_dir"], img_name + ".png")).convert('L')
input_chunk[pos] = resize(input_raw, config["target_shape"], 0)
target_chunk[pos] = renumerate_target(resize(target_raw, config["target_shape"], 2), int(label))
dataset.append(input_chunk, target_chunk)
dataset.lock()
return dataset
Для простоты будем выбирать размер изображений target_shape с одинаковыми сторонами. Предлагается использовать размер 256x256, хотя выбор за вами. Обратите внимание, что от размера изображений зависит быстродействие дальнейшего кода (чем больше картинки, тем дольше обучать).
# Конфигурация датасета
default_config = {
"input_dir": "SegTask/images",
"target_dir": "SegTask/seg_masks",
"target_shape": (256, 256), # Можно любой другой размер картинки
"chunk_size": 512, # количество изображений в блоке, загружаемых в оперативную память
}
# Конфигурации обучающей и тестовой выборок отличаются файлов аннотации
config_train = {"annotation_file": "SegTask/trainval.txt"} | default_config
config_test = {"annotation_file": "SegTask/test.txt"} | default_config
train_data_hdf5 = prepare_dataset(config_train, storage_hdf5)
train_data_memmap = prepare_dataset(config_train, storage_memmap)
train_data_raw = prepare_dataset(config_train, storage_raw)
100%|██████████| 8/8 [00:36<00:00, 4.52s/it] 100%|██████████| 8/8 [00:21<00:00, 2.68s/it] 100%|██████████| 8/8 [00:51<00:00, 6.50s/it]
Pytorch предоставляет нам удобные обертки Dataset и DataLoader для наших данных, которые эффективно нарезают наш датасет на batches (блоки) заданного размера, а также параллелизуют процесс чтения на num_workers нитей.
Также для дальнейшей работы нам понадобится аугментация данных. Ее цель заключается в еще большем расширении обучающей выборки путем применения преобразований над изображениями, которые изменяют их абсолютные значения пикселей, но не нарушают их информационное наполнение.
Например, преобразование ColorJitter способно изменить яркость изображения на случайное число, что не изменяет его контекст. Однако, преобразование RandomCrop не рекомендуется, посколько есть шанс, что мордочка животного не попадет в фото и класс животного будет неоднозначен. Таким образом, при каждом вызове объекта из обучающей выборки к нему будет применяться случайное преобразование/серия случайных преобразований. Обратите внимание, что преобразование изображения должно быть согласованным с его сегментационной маской.
Требуется реализовать предлагаемые ниже преобразования аугментации:
HorizontalFlip (0.25 балла)ColorJitter (0.25 балла)RandomPerspective (0.5 балла)Для каждого из указанных преобразований требуется написать магический метод __call__, который позволяет обращаться к объекту класса (преобразованию), как к функции (функтор из C++):
# инициализация
obj = Example()
# вызывается __call__
obj()
# Предлагается использовать эти функции
# Самому писать процедуры отражения картинки по вертикали/горизонтали или цветокоррекции не надо!
from torchvision.transforms.functional import hflip
from torchvision.transforms.functional import perspective
from torchvision.transforms import ColorJitter as CJ
from torchvision.transforms import RandomPerspective as RP
import random
class HorizontalFlip():
def __init__(self, prob: float):
self.p = prob
def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
"""
`pair` содержит пару (изображение, сегментационная маска)
"""
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
if random.random() > self.p:
return pair
return hflip(pair[0]), hflip(pair[1])
class ColorJitter():
def __init__(self, prob: float, param: tuple[float]):
self.p = prob
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
self.jitter = CJ(brightness=param[0], contrast=param[1], saturation=param[2])
def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
"""
`pair` содержит пару (изображение, сегментационная маска)
"""
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
if random.random() > self.p:
return pair
return self.jitter(pair[0]), pair[1]
class RandomPerspective():
def __init__(self, prob: float, param: float):
self.p = prob
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
self.distortion_scale = param
def __call__(self, pair: tuple[Image, Image]) -> tuple[Image, Image]:
"""
`pair` содержит пару (изображение, сегментационная маска)
"""
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
if random.random() > self.p:
return pair
return [perspective(img, *RP.get_params(*pair[0].size, self.distortion_scale), interpolation=IM.NEAREST) for img in pair]
Применим реализованные преобразования и убедимся в их работоспособности:
img_idx = np.random.randint(0, 100)
f, ax = plt.subplots(2, 4, figsize=(16, 8))
pair = train_data_hdf5[img_idx]
imgs2draw = {"Source": pair,
"HorizontalFlip": HorizontalFlip(1.0)(pair),
"ColorJitter": ColorJitter(1.0, (0.4, 0.4, 0.4))(pair),
"RandomPerspective": RandomPerspective(1.0, 0.25)(pair)
}
for idx, (name, pair) in enumerate(imgs2draw.items()):
ax[0, idx].imshow(pair[0])
ax[0, idx].set_title(name, fontsize=20)
ax[1, idx].imshow(colorize(np.array(pair[1])))
plt.show()
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/torchvision/transforms/functional.py:594: UserWarning: torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release. torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in the returned tuple (although it returns other information about the problem). To get the qr decomposition consider using torch.linalg.qr. The returned solution in torch.lstsq stored the residuals of the solution in the last m - n columns of the returned value whenever m > n. In torch.linalg.lstsq, the residuals in the field 'residuals' of the returned named tuple. The unpacking of the solution, as in X, _ = torch.lstsq(B, A).solution[:A.size(1)] should be replaced with X = torch.linalg.lstsq(A, B).solution (Triggered internally at /pytorch/aten/src/ATen/LegacyTHFunctionsCPU.cpp:389.) res = torch.lstsq(b_matrix, a_matrix)[0]
Далее описываем наш класс SegmentationData и операции приведения изображений типа PIL.Image к pytorch тензорам с ImageNet нормализацией. ImageNet нормализация - это частный случай Standard normalization, в котором поканальное среднее (цветовые каналы red, green, blue) и поканальное среднеквадратическое отклонение вычислены на огромной выборке изображений.
Ответьте на вопрос: А для чего нужно применять нормализацию к изображениям?
Ваш ответ: сверточный слой, так же как и линейный, чувствителен к масштабу входных значений. Обучившись на одном распределении значений пикселей (масштабе), сеть успешно будет работать только для изображений из этого же масштаба. Поэтому мы и нормализуем входные изображения. Внутри сети за нормализацию отвечают уже нормализационные слои (BatchNorm, LayerNorm, etc.).
Кроме того, это еще и вычислительно стабильнее, благодаря распределению значений от -1 до 1 (mean=0, std=1)
# Определяем устройство для вычислений (!желательно GPU!)
DEVICE = 'cuda:1' if torch.cuda.is_available() else 'cpu'
t_dict = {
"forward_input": transforms.Compose([
transforms.PILToTensor(),
transforms.Lambda(lambda x: x.float().to(DEVICE)/255.0),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]),
"backward_input": transforms.Compose([
transforms.Normalize(mean=[0.0, 0.0, 0.0],
std=[1./0.229, 1./0.224, 1./0.225]),
transforms.Normalize(mean=[-0.485, -0.456, -0.406],
std=[1.0, 1.0, 1.0]),
transforms.Lambda(lambda x: x.permute(1, 2, 0).cpu().numpy())
]),
"forward_target": transforms.Compose([
transforms.PILToTensor(),
transforms.Lambda(lambda x: x.long().squeeze().to(DEVICE)),
]),
"backward_target": transforms.Compose([
transforms.Lambda(lambda x: x.cpu().numpy())
]),
"augment": transforms.Compose([
HorizontalFlip(0.5),
ColorJitter(0.5, (0.4, 0.4, 0.4)),
RandomPerspective(0.5, 0.25)
]),
}
class SegmentationDataset(Dataset):
def __init__(self, dataset_raw: Type[storage_class], transforms: dict, train_flag: bool = True):
"""
Наследуем весь функционал из `Dataset` для наших данных `dataset_raw`
`transforms` содержит преобразования PIL.Image <-> torch.tensor и аугментации
`train_flag` регулирует аугментацию данных (для тестовой выборки она не нужна)
"""
super().__init__()
self.dataset_raw = dataset_raw
self.transforms = transforms
self.train_flag = train_flag
def __len__(self):
return self.dataset_raw.dataset_size
def __getitem__(self, idx: int) -> tuple[Image, Image]:
input, target = self.dataset_raw[idx]
if (self.train_flag):
input, target = self.transforms["augment"]((input, target))
return self.transforms["forward_input"](input), self.transforms["forward_target"](target)
from torch.utils.data import random_split
# Разделяем обучающую выборку на обучающую и валидационную
def split_train_val(train_data: Type[storage_class], train_portion: float = 0.8):
"""
`train_data` предобработанные данные
`train_portion` доля объектов, которая будет приходиться на обучающую выборку
"""
trainval_dataset = SegmentationDataset(train_data, t_dict, train_flag=True)
train_size = int(len(trainval_dataset) * train_portion)
val_size = len(trainval_dataset) - train_size
return random_split(trainval_dataset, [train_size, val_size])
# train_dataset_hdf5, val_dataset_hdf5 = split_train_val(train_data_hdf5)
train_dataset_memmap, val_dataset_memmap = split_train_val(train_data_memmap)
# train_dataset_raw, val_dataset_raw = split_train_val(train_data_raw)
from torch.utils.data import random_split
# Разделяем обучающую выборку на обучающую и валидационную
def split_train_val(train_data: Type[storage_class], train_portion: float = 0.8):
"""
`train_data` предобработанные данные
`train_portion` доля объектов, которая будет приходиться на обучающую выборку
"""
trainval_dataset = SegmentationDataset(train_data, t_dict, train_flag=True)
train_size = int(len(trainval_dataset) * train_portion)
val_size = len(trainval_dataset) - train_size
return random_split(trainval_dataset, [train_size, val_size])
train_dataset_hdf5, val_dataset_hdf5 = split_train_val(train_data_hdf5)
train_dataset_memmap, val_dataset_memmap = split_train_val(train_data_memmap)
train_dataset_raw, val_dataset_raw = split_train_val(train_data_raw)
Отрисуем случайное изображение (после применения случайных преобразований аугментации):
img_idx = np.random.randint(0, 100)
draw(train_dataset_hdf5[img_idx], t_dict);
/home/r.fazylov/anaconda3/envs/pl_vl_170/lib/python3.9/site-packages/torchvision/transforms/functional.py:165: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.) img = torch.as_tensor(np.asarray(pic))
dataloader_config = {
"batch_size": 16, # Ваше значение
"shuffle": True,
"num_workers": 0
}
train_dataloader_hdf5 = DataLoader(train_dataset_hdf5, **dataloader_config)
val_dataloader_hdf5 = DataLoader(val_dataset_hdf5, **dataloader_config)
train_dataloader_memmap = DataLoader(train_dataset_memmap, **dataloader_config)
val_dataloader_memmap = DataLoader(val_dataset_memmap, **dataloader_config)
train_dataloader_raw = DataLoader(train_dataset_raw, **dataloader_config)
val_dataloader_raw = DataLoader(val_dataset_raw, **dataloader_config)
Замерьте время чтения нашего датасета для каждого из форматов хранения:
def speedtest(dataloader: Type[DataLoader]) -> None:
for batch in dataloader:
pass
%timeit speedtest(train_dataloader_hdf5)
8.23 s ± 156 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit speedtest(train_dataloader_memmap)
9.48 s ± 164 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit speedtest(train_dataloader_raw)
10.8 s ± 208 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Ответьте на вопрос: Какой формат оказался самым эффективным по скорости? Почему?
Ваш ответ: самым эффективным оказался hdf5 благодаря тому что мы рандомно обращаемся к данным, то есть непоследовательно (shuffle=True). А он как раз и предназначен для этого
Создайте тестовый Dataloader победившего по скорости формата.
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
test_data_memmap = prepare_dataset(config_test, storage_memmap)
test_ds = SegmentationDataset(test_data_memmap, t_dict, train_flag=False)
test_dataloader = DataLoader(test_ds, batch_size=8, shuffle=False, num_workers=0)
100%|██████████| 8/8 [02:44<00:00, 20.61s/it]
Ранее вас познакомили с архитектурой Unet - сверточными автокодировщиком, применяемом в области сегментации изображений. В данном задании мы разберем более продвинутую архитектуру сети сегментации PSPNet. Отличительной особенностью этой сети является Pyramid Pooling Module, который, в отличие от Unet, позволяет учитывать глобальный контекст изображения при формировании признаков его локальных областей.
Рассмотрим предлагаемую архитектуру PSPNet-подобной сети:
В качестве кодировщика Encoder будем брать предобученную ResNeXt сеть. Будем его использовать для получения двух глубинных представлений нашего входого изображения x:
x_main - "среднее" промежуточное представление, компромисс между низкоуровневыми признаками (цвет, контуры объектов, штрихи) и высокоуровневыми признаками (абстрактные признаки, отражающие семантику изображения)x_supp - финальное представление, содержащее самые высокоуровневые признаки, в которых значительно утеряна информация о точном простанственном расположении объектовПодобное разбиение выхода на 2 потока объясняется необходимостью в закодированной информации о пространственном расположении объектов (x_main) и вспомогательной информации о семантике всего изображения в целом (x_supp) для задачи семантической сегментации. Мы не можем себе позволить использовать лишь выход x_supp, как это делается, например, в задачах классификации, ведь от нас требуется дополнительное знание о расположении этого объекта на изображении.
Ваша задача состоит в написании декодировщика Decoder, а именно в написании блоков:
Pyramid Pooling Module. К входному тензору x_main параллельно применяется несколько операций пулинга разных размеров, которые сводят пространственные размерности исходного тензора до размеров 1x1, 2x2, 3x3 и 6x6. Каналы промежуточных тензоров эффективно редуцируются (при помощи nn.Conv2d c размером фильтра 1x1), а затем пространственные размерности интерполируются до исходных размеров. Эта процедура необходима для извлечения глобального контекста разных масштабов, которого не хватает классическим сверточным нейронным сетям (локальный контекст в пределах размера фильтра). Таким образом, выходной тензор, полученный конкатенацией этих глобальных контекстов, содержит информацию о всем входном тензоре с разными уровнями детализации. Промежуточное редуцирование каналов тензоров производится для сжатия информации, а также для индивидуального взвешивания глобального контекста каждого масштаба. Требуется реализовать forward этап этого блока. Для уточнения информации можно обратиться к статье.Supplementary Module осуществляет нелинейное преобразование над входным тензором x_supp с понижением числа каналов до размерности выхода модуля Pyramid Pooling Module. Вариант архитектуры этого преобразования (композиции слоев) уже предложен, но, при желании, вы можете с ним экспериментироватьUpsample Module осуществляет нелинейные преобразования над входным тензором с понижением числа каналов, которые чередуются с интерполяцией пространственных размерностей в 2 раза (увеличение). Таким образом, выход этого блока имеет ту же пространственную размерность, что и входное в кодировщик изображение. Это преобразование (слои Layer 0, Layer 1 и Layer 2) требуется экспериментально подобратьSegmentation Head нелинейно преобразует входной тензор в тензор score'ов. Имеем, что выходной тензор для каждого пикселя имеет num_classes score'ов (в нашем случае 3). В дальнейшем, индекс максимального score'а для заданного пикселя и будет его меткой класса (0, 1 или 2). Это преобразование (композиция слоев) требуется экспериментально подобратьЕсли декодировщик получается слишком тяжелый, то оператор Concat можно заменить на поканальное суммирование. Обратите внимание, что нет единственной правильной архитектуры указанных выше блоков. Требуется ее экспериментально подобрать так, чтобы получить наилучшее качество сегментации за разумную сложность.
from torchvision.models.resnet import ResNet
from torchvision.models import resnext50_32x4d
pretrained_model = resnext50_32x4d(pretrained=True)
# Выставляем evaluation mode (влияет на поведение таких слоев как BatchNorm2d, Dropout)
pretrained_model.eval()
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1000, bias=True)
)
Так как кодировщик используется предобученный, то требуется зафиксировать (заморозить) веса, чтобы по ним не тек градиент. Этим мы гарантируем, что кодировщик не изменяется в ходе обучения автокодировщика, а также экономим вычислительные ресурсы (граф градиента кодировщика не строится).
class EncoderBlock(nn.Module):
def __init__(self, pretrained_model: Type[ResNet]):
"""
Извлекает предобученные именованные слои кодировщика `pretrained_model`
Разделяет слои на `main` и `supp` потоки (см. архитектуру выше)
Вход: тензор (Batch_size, 3, Height, Width)
Выход: x_main тензор (Batch_size, 512, Height // 8, Width // 8)
Выход: x_supp тензор (Batch_size, 2048, Height // 32, Width // 32)
"""
super().__init__()
self.encoder_main = nn.Sequential()
for name, child in list(pretrained_model.named_children())[:-4]:
print(f"Pretrained main module {name} is loaded")
self.encoder_main.add_module(name, child)
self.encoder_supp = nn.Sequential()
for name, child in list(pretrained_model.named_children())[-4:-2]:
print(f"Pretrained supp module {name} is loaded")
self.encoder_supp.add_module(name, child)
def freeze(self) -> None:
"""
Замораживает веса кодировщика
"""
for p in self.parameters():
p.requires_grad = False
self.eval()
def unfreeze(self) -> None:
"""
Размораживает веса кодировщика
"""
for p in self.parameters():
p.requires_grad = True
self.train()
def forward(self, x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
x_main = self.encoder_main(x)
x_supp = self.encoder_supp(x_main)
return x_main, x_supp
encoder = EncoderBlock(pretrained_model)
Pretrained main module conv1 is loaded Pretrained main module bn1 is loaded Pretrained main module relu is loaded Pretrained main module maxpool is loaded Pretrained main module layer1 is loaded Pretrained main module layer2 is loaded Pretrained supp module layer3 is loaded Pretrained supp module layer4 is loaded
Для оценки сложности модели нам понадобится функция подсчета числа ее параметров, для этого используйте метод .parameters(). Реализуйте ее ниже:
Также полезно было бы убедиться, что метод .parameters() возвращает то, что мы от него ожидаем. Для этого воспользуемся методом .named_parameters():
for name, parameter in encoder.named_parameters():
print(name)
encoder_main.conv1.weight encoder_main.bn1.weight encoder_main.bn1.bias encoder_main.layer1.0.conv1.weight encoder_main.layer1.0.bn1.weight encoder_main.layer1.0.bn1.bias encoder_main.layer1.0.conv2.weight encoder_main.layer1.0.bn2.weight encoder_main.layer1.0.bn2.bias encoder_main.layer1.0.conv3.weight encoder_main.layer1.0.bn3.weight encoder_main.layer1.0.bn3.bias encoder_main.layer1.0.downsample.0.weight encoder_main.layer1.0.downsample.1.weight encoder_main.layer1.0.downsample.1.bias encoder_main.layer1.1.conv1.weight encoder_main.layer1.1.bn1.weight encoder_main.layer1.1.bn1.bias encoder_main.layer1.1.conv2.weight encoder_main.layer1.1.bn2.weight encoder_main.layer1.1.bn2.bias encoder_main.layer1.1.conv3.weight encoder_main.layer1.1.bn3.weight encoder_main.layer1.1.bn3.bias encoder_main.layer1.2.conv1.weight encoder_main.layer1.2.bn1.weight encoder_main.layer1.2.bn1.bias encoder_main.layer1.2.conv2.weight encoder_main.layer1.2.bn2.weight encoder_main.layer1.2.bn2.bias encoder_main.layer1.2.conv3.weight encoder_main.layer1.2.bn3.weight encoder_main.layer1.2.bn3.bias encoder_main.layer2.0.conv1.weight encoder_main.layer2.0.bn1.weight encoder_main.layer2.0.bn1.bias encoder_main.layer2.0.conv2.weight encoder_main.layer2.0.bn2.weight encoder_main.layer2.0.bn2.bias encoder_main.layer2.0.conv3.weight encoder_main.layer2.0.bn3.weight encoder_main.layer2.0.bn3.bias encoder_main.layer2.0.downsample.0.weight encoder_main.layer2.0.downsample.1.weight encoder_main.layer2.0.downsample.1.bias encoder_main.layer2.1.conv1.weight encoder_main.layer2.1.bn1.weight encoder_main.layer2.1.bn1.bias encoder_main.layer2.1.conv2.weight encoder_main.layer2.1.bn2.weight encoder_main.layer2.1.bn2.bias encoder_main.layer2.1.conv3.weight encoder_main.layer2.1.bn3.weight encoder_main.layer2.1.bn3.bias encoder_main.layer2.2.conv1.weight encoder_main.layer2.2.bn1.weight encoder_main.layer2.2.bn1.bias encoder_main.layer2.2.conv2.weight encoder_main.layer2.2.bn2.weight encoder_main.layer2.2.bn2.bias encoder_main.layer2.2.conv3.weight encoder_main.layer2.2.bn3.weight encoder_main.layer2.2.bn3.bias encoder_main.layer2.3.conv1.weight encoder_main.layer2.3.bn1.weight encoder_main.layer2.3.bn1.bias encoder_main.layer2.3.conv2.weight encoder_main.layer2.3.bn2.weight encoder_main.layer2.3.bn2.bias encoder_main.layer2.3.conv3.weight encoder_main.layer2.3.bn3.weight encoder_main.layer2.3.bn3.bias encoder_supp.layer3.0.conv1.weight encoder_supp.layer3.0.bn1.weight encoder_supp.layer3.0.bn1.bias encoder_supp.layer3.0.conv2.weight encoder_supp.layer3.0.bn2.weight encoder_supp.layer3.0.bn2.bias encoder_supp.layer3.0.conv3.weight encoder_supp.layer3.0.bn3.weight encoder_supp.layer3.0.bn3.bias encoder_supp.layer3.0.downsample.0.weight encoder_supp.layer3.0.downsample.1.weight encoder_supp.layer3.0.downsample.1.bias encoder_supp.layer3.1.conv1.weight encoder_supp.layer3.1.bn1.weight encoder_supp.layer3.1.bn1.bias encoder_supp.layer3.1.conv2.weight encoder_supp.layer3.1.bn2.weight encoder_supp.layer3.1.bn2.bias encoder_supp.layer3.1.conv3.weight encoder_supp.layer3.1.bn3.weight encoder_supp.layer3.1.bn3.bias encoder_supp.layer3.2.conv1.weight encoder_supp.layer3.2.bn1.weight encoder_supp.layer3.2.bn1.bias encoder_supp.layer3.2.conv2.weight encoder_supp.layer3.2.bn2.weight encoder_supp.layer3.2.bn2.bias encoder_supp.layer3.2.conv3.weight encoder_supp.layer3.2.bn3.weight encoder_supp.layer3.2.bn3.bias encoder_supp.layer3.3.conv1.weight encoder_supp.layer3.3.bn1.weight encoder_supp.layer3.3.bn1.bias encoder_supp.layer3.3.conv2.weight encoder_supp.layer3.3.bn2.weight encoder_supp.layer3.3.bn2.bias encoder_supp.layer3.3.conv3.weight encoder_supp.layer3.3.bn3.weight encoder_supp.layer3.3.bn3.bias encoder_supp.layer3.4.conv1.weight encoder_supp.layer3.4.bn1.weight encoder_supp.layer3.4.bn1.bias encoder_supp.layer3.4.conv2.weight encoder_supp.layer3.4.bn2.weight encoder_supp.layer3.4.bn2.bias encoder_supp.layer3.4.conv3.weight encoder_supp.layer3.4.bn3.weight encoder_supp.layer3.4.bn3.bias encoder_supp.layer3.5.conv1.weight encoder_supp.layer3.5.bn1.weight encoder_supp.layer3.5.bn1.bias encoder_supp.layer3.5.conv2.weight encoder_supp.layer3.5.bn2.weight encoder_supp.layer3.5.bn2.bias encoder_supp.layer3.5.conv3.weight encoder_supp.layer3.5.bn3.weight encoder_supp.layer3.5.bn3.bias encoder_supp.layer4.0.conv1.weight encoder_supp.layer4.0.bn1.weight encoder_supp.layer4.0.bn1.bias encoder_supp.layer4.0.conv2.weight encoder_supp.layer4.0.bn2.weight encoder_supp.layer4.0.bn2.bias encoder_supp.layer4.0.conv3.weight encoder_supp.layer4.0.bn3.weight encoder_supp.layer4.0.bn3.bias encoder_supp.layer4.0.downsample.0.weight encoder_supp.layer4.0.downsample.1.weight encoder_supp.layer4.0.downsample.1.bias encoder_supp.layer4.1.conv1.weight encoder_supp.layer4.1.bn1.weight encoder_supp.layer4.1.bn1.bias encoder_supp.layer4.1.conv2.weight encoder_supp.layer4.1.bn2.weight encoder_supp.layer4.1.bn2.bias encoder_supp.layer4.1.conv3.weight encoder_supp.layer4.1.bn3.weight encoder_supp.layer4.1.bn3.bias encoder_supp.layer4.2.conv1.weight encoder_supp.layer4.2.bn1.weight encoder_supp.layer4.2.bn1.bias encoder_supp.layer4.2.conv2.weight encoder_supp.layer4.2.bn2.weight encoder_supp.layer4.2.bn2.bias encoder_supp.layer4.2.conv3.weight encoder_supp.layer4.2.bn3.weight encoder_supp.layer4.2.bn3.bias
def count_parameters(model: Type[nn.Module]) -> int:
"""
Считает число весов в модели `model`, для которых требуется градиент
"""
total_params = 0
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
for name, param in model.named_parameters():
if param.requires_grad:
total_params += param.numel()
return total_params
Убедимся, что метод .freeze() успешно замораживает веса:
print("Encoder #parameters before freeze():", count_parameters(encoder))
encoder.freeze()
print("Encoder #parameters after freeze():", count_parameters(encoder))
Encoder #parameters before freeze(): 22979904 Encoder #parameters after freeze(): 0
Реализуйте PyramidPoolingModule, Upsample и SegmentationHead (по 0.5 балла), а также заполните пропущенные значения ??? в UpsampleModule и DecoderBlock (по 0.25 балла). Выбор параметров/архитектуры сети в большей степени зависит от результатов обучения в следующей части задания (так что вы еще вернетесь к этому пункту).
class PyramidPoolingModule(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bin_sizes: tuple[int, ...]):
"""
Вход: тензор (Batch_size, `in_channels`, Height, Width)
`bin_sizes` - пространственные размерности для каждой пулинг операции
Пример: bin_sizes = (1, 2, 3, 6).
Выход: тензор (Batch_size, `in_channels` + len(`bin_sizes`) * `out_channels`, Height, Width)
"""
super().__init__()
self.bins = []
for bin_size in bin_sizes:
self.bins.append(nn.Sequential(
nn.AdaptiveAvgPool2d(bin_size),
nn.Conv2d(in_channels, out_channels, kernel_size=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
))
self.bins = nn.ModuleList(self.bins)
def forward(self, x: torch.tensor) -> torch.tensor:
h, w = x.shape[2:]
out = []
"""
Осуществите все пулинг-операции с последующим `Upscale`
Подсказка: используйте torch.functional.interpolate
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
for bin in self.bins:
out.append(F.interpolate(bin(x), size=(h, w), mode='bilinear', align_corners=True))
return torch.cat([x] + out, dim=1)
class SupplementaryModule(nn.Module):
def __init__(self, in_channels: int, out_channels: int, dropout: float):
"""
Вход: тензор (Batch_size, `in_channels`, Height, Width)
Выход: тензор (Batch_size, `out_channels`, Height, Width)
"""
super().__init__()
mid_channels = 512 # TODO: in_channels // 2 (1024), 512 / 256
self.suppl = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Dropout2d(p=dropout), # TODO: убрать dropout
nn.Conv2d(mid_channels, out_channels, kernel_size=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
"""
Указанную выше архитектуру можно менять по своему усмотрению
"""
def forward(self, x: torch.tensor) -> torch.tensor:
return self.suppl(x)
class Upsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
"""
Вход: тензор (Batch_size, `in_channels`, Height, Width)
Выход: тензор (Batch_size, `out_channels`, 2 * Height, 2 * Width)
"""
super().__init__()
self.us_transform = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), # TODO: 2xConv2d / ConvTransposed2d + 1xConv2d
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x: torch.tensor) -> torch.tensor:
"""
Подсказка: используйте torch.functional.interpolate для удвоения пространственных размерностей
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
w, h = x.shape[2:]
x = self.us_transform(x)
x = F.interpolate(x, size=(2*w, 2*h), mode='bilinear', align_corners=True)
return x
class UpsampleModule(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
"""
Вход: тензор (Batch_size, `in_channels`, Height, Width)
Выход: тензор (Batch_size, `out_channels`, 8 * Height, 8 * Width)
"""
super().__init__()
m1_channels = in_channels // 2 # TODO: так как входных каналов мало (PPM_out_channels + supp_out_channels)
m2_channels = m1_channels // 2
self.upsample = nn.Sequential(
Upsample(in_channels, m1_channels),
Upsample(m1_channels, m2_channels),
Upsample(m2_channels, out_channels)
)
def forward(self, x: torch.tensor) -> torch.tensor:
return self.upsample(x)
class DecoderBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int, bin_sizes: tuple[int, ...], dropout: float = 0.1):
"""
Вход x_main: тензор (Batch_size, `in_channels`, Height, Width)
Вход x_supp: тензор (Batch_size, 4 * `in_channels`, Height // 4, Width // 4)
Выход: тензор (Batch_size, `out_channels`, 8 * Height, 8 * Width)
"""
super().__init__()
assert in_channels % len(bin_sizes) == 0 # in_channels = 512
bin_out_channels = 1 # в статье 1
PPM_out_channels = len(bin_sizes) * bin_out_channels + in_channels # 516
supp_out_channels = 32 # абстрактная информация о классе, не очень нужна имхо, 64 / 32 хватит
self.PPM = PyramidPoolingModule(in_channels, bin_out_channels, bin_sizes)
self.SM = SupplementaryModule(4 * in_channels, supp_out_channels, dropout)
self.UM = UpsampleModule(PPM_out_channels + supp_out_channels, out_channels)
def forward(self, x_main: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
h_supp, w_supp = x_supp.shape[2:]
x_supp = F.interpolate(input=x_supp, size=(4 * h_supp, 4 * w_supp), mode='bilinear', align_corners=True)
x_supp = self.SM(x_supp)
x_main = self.PPM(x_main)
out = self.UM(torch.cat([x_main, x_supp], dim=1))
return out
class SegmentationHead(nn.Module):
def __init__(self, in_channels: int, num_classes: int, dropout: float = 0.0):
"""
Вычисляет score для каждого из классов
Вход: тензор (Batch_size, `in_channels`, Height, Width)
Выход: тензор (Batch_size, `num_classes`, Height, Width)
"""
super().__init__()
self.segmentation_head = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
nn.BatchNorm2d(in_channels // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(in_channels // 2, num_classes, kernel_size=1),
)
def forward(self, x: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
"""
На будущее зададим фиктивный аргумент `x_supp`, который пока не будем использовать
"""
return self.segmentation_head(x)
В задаче сегментации для оценки предсказательной способности нейронной сети, в основном, используют следующие метрики:
Пусть $\mathrm{P}$ обозначает прогноз сег. маски (Prediction), $\mathrm{S}$ обозначает score'ы для каждого класса сег. маски (Scores), а $\mathrm{T}$ означает сег. маску (Target). Тогда:
макро- или микро-усредним для них метрики. Требуется реализовать мультиклассовые варианты указанных метрик с поддержкой макро- и микро-усреднения (по 1 баллу). Обратите внимание, что метрики рассчитываются для каждого элемента из батча. За редуцирование метрик вдоль размерности батча отвечает аргумент reduce (см. ниже).Также для обучения будем использовать две разные, но схожие функции потерь:
Требуется реализовать обе функции потерь. Также всюду необходимо обеспечить корректную обработку значений ignore_index, которые в нашем случае равны 255 (не участвуют в расчете метрик/функций потерь). Если представители некоторых классов в $\mathrm{T}$ отсутствуют, то учитывать эти классы при макро-усреднении не нужно.
class MetricsCollection():
def __init__(self, num_classes: int, ignore_index: int = 255):
self.num_classes = num_classes
self.ignore_index = ignore_index
def IoUMetric(self, prediction: torch.tensor, target: torch.tensor, average: str = "macro", reduce: str = "mean") -> Union[torch.tensor, float]:
"""
`prediction` предсказанная сегментационная маска размера (Batch_size, Height, Width)
`target` истинная сегментационная маска размера (Batch_size, Height, Width)
`average` тип мультклассового усреднения
`reduce` редукция значений метрики вдоль размерности Batch; None - без редукции
"""
assert average in ["micro", "macro"]
assert reduce in ["sum", "mean", "none"]
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
ignore_idx = torch.where(target == self.ignore_index)
n_batch = prediction.shape[0]
nums = torch.zeros(n_batch, self.num_classes)
dens = torch.zeros(n_batch, self.num_classes)
dens_macro = torch.zeros(n_batch, self.num_classes)
classes_in_target = torch.ones(n_batch, 1) * self.num_classes
for c in range(self.num_classes):
pred_c = prediction.clone()
pred_c[prediction == c] = 1
pred_c[prediction != c] = 0
pred_c[ignore_idx] = 0
target_c = target.clone()
target_c[target == c] = 1
target_c[target != c] = 0
target_c[ignore_idx] = 0
nums[:, c] = (pred_c * target_c).sum(dim=(1, 2))
dens[:, c] = (pred_c + target_c - pred_c * target_c).sum(dim=(1, 2))
not_in_target = (target_c.sum(dim=(1,2)) == 0).view(-1)
classes_in_target[not_in_target] -= 1
dens_macro[:, c] = dens[:, c]
dens_macro[not_in_target, c] = 1 # prevent zero division
if average == 'micro':
iou = nums.sum(dim=1) / dens.sum(dim=1)
else:
iou = (nums / dens_macro)
iou = iou.sum(dim=1)
iou = iou / classes_in_target
if reduce == 'sum':
return iou.sum()
elif reduce == 'mean':
return iou.mean()
return iou
def RecallMetric(self, prediction: torch.tensor, target: torch.tensor, average: str = "macro", reduce: str = "mean") -> Union[torch.tensor, float]:
"""
`prediction` предсказанная сегментационная маска размера (Batch_size, Height, Width)
`target` истинная сегментационная маска размера (Batch_size, Height, Width)
`average` тип мультклассового усреднения
`reduce` редукция значений метрики вдоль размерности Batch; None - без редукции
"""
assert average in ["micro", "macro"]
assert reduce in ["sum", "mean", "none"]
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
ignore_idx = torch.where(target == self.ignore_index)
n_batch = prediction.shape[0]
nums = torch.zeros(n_batch, self.num_classes)
dens = torch.zeros(n_batch, self.num_classes)
dens_macro = torch.zeros(n_batch, self.num_classes)
classes_in_target = torch.ones(n_batch, 1) * self.num_classes
for c in range(self.num_classes):
pred_c = prediction.clone()
pred_c[prediction == c] = 1
pred_c[prediction != c] = 0
pred_c[ignore_idx] = 0
target_c = target.clone()
target_c[target == c] = 1
target_c[target != c] = 0
target_c[ignore_idx] = 0
nums[:, c] = (pred_c * target_c).sum(dim=(1, 2))
dens[:, c] = (target_c).sum(dim=(1, 2))
not_in_target = (target_c.sum(dim=(1,2)) == 0).view(-1)
classes_in_target[not_in_target] -= 1
dens_macro[:, c] = dens[:, c]
dens_macro[not_in_target, c] = 1 # prevent zero division
if average == 'micro':
iou = nums.sum(dim=1) / dens.sum(dim=1)
else:
iou = (nums / dens_macro)
iou = iou.sum(dim=1)
iou = iou / classes_in_target
if reduce == 'sum':
return iou.sum()
elif reduce == 'mean':
return iou.mean()
return iou
def FocalLoss(self, scores: torch.tensor, target: torch.tensor, reduce: str = "mean", gamma: float = 1.) -> Union[torch.tensor, float]:
"""
`scores` score'ы каждого класса сегментационной маски размера (Batch_size, num_classes, Height, Width)
`target` истинная сегментационная маска размера (Batch_size, Height, Width)
`reduce` редукция значений функции потерь вдоль размерности Batch; None - без редукции
"""
assert scores.shape[1] == self.num_classes
assert reduce in ["sum", "mean", "none"]
ce_loss = F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="none")
coef = (1 - torch.exp(-ce_loss))**gamma
focal_loss = coef * ce_loss
norm = (focal_loss.numel() - (target == self.ignore_index).sum())
if (reduce == "sum"):
return focal_loss.sum() / norm * scores.shape[0]
elif (reduce == "mean"):
return focal_loss.sum() / norm
else:
return focal_loss.sum(dim=[1, 2]) / norm * scores.shape[0]
def CrossEntropyLoss(self, scores: torch.tensor, target: torch.tensor, reduce: str = "mean") -> Union[torch.tensor, float]:
"""
`scores` score'ы каждого класса сегментационной маски размера (Batch_size, num_classes, Height, Width)
`target` истинная сегментационная маска размера (Batch_size, Height, Width)
`reduce` редукция значений функции потерь вдоль размерности Batch; None - без редукции
"""
assert scores.shape[1] == self.num_classes
assert reduce in ["sum", "mean", "none"]
if (reduce == "sum"):
return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="mean") * scores.shape[0]
elif (reduce == "mean"):
return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="mean")
else:
return F.cross_entropy(scores, target, ignore_index=self.ignore_index, reduction="none")
@classmethod
def ListMetrics(cls):
return [method for method in dir(cls) if (method.endswith("Metric"))]
@classmethod
def ListLosses(cls):
return [method for method in dir(cls) if (method.endswith("Loss"))]
metric_class = MetricsCollection(num_classes=3, ignore_index=255)
prediction = torch.tensor([[[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
[[0, 0, 0, 0], [0, 2, 2, 0], [0, 2, 0, 0], [0, 0, 0, 0]]])
target = torch.tensor([[[0, 0, 0, 0], [0, 1, 255, 0], [0, 1, 255, 0], [0, 0, 0, 0]],
[[0, 0, 0, 0], [0, 255, 2, 0], [0, 255, 2, 0], [0, 0, 0, 0]]])
assert np.isclose(metric_class.RecallMetric(prediction, target, "micro", "mean").item(), 0.9286, atol=1e-3)
assert np.isclose(metric_class.RecallMetric(prediction, target, "macro", "mean").item(), 0.7500, atol=1e-3)
assert np.isclose(metric_class.IoUMetric(prediction, target, "micro", "mean").item(), 0.8667, atol=1e-3)
assert np.isclose(metric_class.IoUMetric(prediction, target, "macro", "mean").item(), 0.7115, atol=1e-3)
Ответьте на вопрос (№1): Что говорит о предсказательной способности нашей сети ситуация: высокий Recall и низкий IoU для некоторого класса? Возможна ли обратная ситуация?
Ваш ответ: это говорит о том, что сеть предсказывает объект размера больше, чем он есть на самом деле. Обратная ситуация, низкий recall и высокий iou, быть не может, так как если маленький recall, то будет и малеьникй P * T (числитель IoU).
Ответьте на вопрос (№2): Какой вид усреднения правильней использовать в нашей задаче: макро и микро? Почему?
Ваш ответ: макро, так как пиксели-классы внутри одного изображения всегда будут несбалансированны. Макро лучше подходит для дисбаланса, ошибки на всех классах влияют на макро-метрику равнозначно. А микро же благодаря усреднению по всем семплам всех классов будет нивелировать возможные ошибки в маленьких классах.
Ответьте на вопрос (№3): В чем преимущество Focal Loss перед Cross Entropy Loss? Что контроллирует гиперпараметр 𝛾 в Focal Loss?
Ваш ответ: Focal Loss помогает при дисбалансе классов, это благодаря тому, что он хорошо штрафует за ошибки, при этом одинаково поощряет за уверенные ответы (скажем за если модель выдала p>=0.8, то ответ уверенный и при ground truth=1 поощряться такие ответы будут одинаково). Гиперпараметр \gamma отвечает как раз таки с какого p мы считаем, что ответ уверенный
Теперь осталось лишь собрать все написанное ранее воедино и обучить нашу сеть. Чтобы контроллировать процесс обучения нашей сети, будем вычислять усредненные метрики и функции потерь на валидационной выборке. Для удобства отображения информации воспользуемся инструментом tensorboard. Для этого заведем объект класса SummaryWriter, который создаст и откроет на запись специальный event файл для tensorboard. Для визуализации содержимого вводится команда tensorboard --logdir=<PATH> в терминале. Если возникла необходимость в мониториге нескольких tensorboard, то каждому из них требуется присвоить свой уникальный порт --port <PORT>. Пример использования tensorboard на Google Colab.
Требуется написать методы train_model и test_model. Вся конфигурация обучения хранится в словаре train_config. При желании его можно дополнить чем-то своим.
К вашему решению потребуется прикрепить логи tensorboard. Чтобы облегчить процедуру проверки настоятельно рекомендуется пользоваться inline-tensorboard:
%load_ext tensorboard
%tensorboard --logdir ./runs
class PSPNet(nn.Module):
def __init__(self, pretrained_model: Type[ResNet], HeadBlock: Type[nn.Module], num_classes: int, train_config: dict, bin_sizes: tuple[int, ...] = (1, 2, 3, 6)):
"""
`pretrained_model` модель предобученного кодировщика
`Head` класс блока, оценивающего score'ы для каждого класса сегментационной маски
`num_class` число классов сегментации
`train_config` словарь с конфигурацией процесса обучения сети
`bin_sizes` пространственные размеры к которым сводит пулинг в блоке PPM
"""
super().__init__()
self.encoder = EncoderBlock(pretrained_model)
self.encoder.freeze()
mid_channels = 256 # TODO: 256 / 128 / 64
self.decoder = DecoderBlock(512, mid_channels, bin_sizes)
self.head = HeadBlock(mid_channels, num_classes)
self.train_config = train_config
self.metric_class = train_config["metric_class"]
self.optimizer = train_config["optimizer"](self.parameters(), **train_config["optimizer_params"])
self.scheduler = train_config["scheduler"](self.optimizer, **train_config["scheduler_params"])
def forward(self, x: torch.tensor) -> tuple[torch.tensor, torch.tensor]:
# Для гарантии отсутствия градиентов по кодировщику
with torch.no_grad():
x_main, x_supp = self.encoder(x)
out = self.decoder(x_main, x_supp)
out = self.head(out, x_supp)
return out, torch.argmax(out.detach(), dim=1)
def write_val_metrics(self, val_metrics: dict, iter_num: int, norm: float = 1.0) -> None:
"""
Записывает усредненные значения метрик/функций потерь в tensorboard
`val_metrics` словарь с ключами "название_метрики/функции потерь" и их значениями
`iter_num` номер глобальной итерации (по формуле #всего_итераций * номер_эпохи + номер_итерации)
`norm` фактор нормализации; для усреднения равен числу объектов в валидационной выборке
"""
for method, value in val_metrics.items():
self.train_config["writer"].add_scalar(f"Mean {method}", np.round(val_metrics[method].item()/norm, 2), iter_num)
def validate_model(self, val_dataloader: Type[DataLoader], iter_num: int) -> None:
"""
Валидирует текущую модель и вычисляет соответствующие метрики/функции потерь
`val_dataloader` валидационная выборка
`iter_num` номер глобальной итерации (по формуле #всего_итераций * номер_эпохи + номер_итерации)
"""
# Выставляет декодировщик в режим валидации (влияет на поведение BatchNorm2d и Dropout)
self.decoder.eval()
# Инициализация словаря метрик/функций потерь
val_metrics = dict([(method, 0.0) for method in (self.metric_class.ListMetrics() + self.metric_class.ListLosses())])
# Обязательно считать с контекстным менеджером torch.no_grad()
# Даже если мы не делаем шаг оптимизации, мы экономим память (не считаем градиенты)
with torch.no_grad():
for input, target in val_dataloader:
scores, prediction = self.forward(input)
for metric in self.metric_class.ListMetrics():
val_metrics[metric] += getattr(self.metric_class, metric)(prediction, target, reduce="sum")
for loss in self.metric_class.ListLosses():
val_metrics[loss] += getattr(self.metric_class, loss)(scores, target, reduce="sum")
# Tensorboard также позволяет сохранять визуализацию наших предсказаний в ходе обучения
figure = draw((input[0], target[0]), t_dict, prediction[0], log=True)
self.train_config["writer"].add_figure("image/GT/prediction", figure, iter_num)
self.write_val_metrics(val_metrics, iter_num, norm=len(val_dataloader.dataset))
# Возвращает режим обучения декодировщика
self.decoder.train()
def train_model(self, train_dataloader: Type[DataLoader], val_dataloader: Type[DataLoader]) -> None:
"""
Обучает модель на обучающей выборке, периодически (периодичность выставляется в train_config) валидирует на валидационной выборке
В конце каждой эпохи сохраняет модель на диск
`train_dataloader` обучающая выборка
`val_dataloader` валидационная выборка
"""
# Выставляет режим обучения декодировщика
self.decoder.train()
for epoch in range(self.train_config["num_epochs"]):
for iter_num, (input, target) in enumerate(train_dataloader):
self.optimizer.zero_grad()
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
scores, pred = self.forward(input)
loss = self.train_config['loss_fn'](scores, target)
loss.backward()
self.optimizer.step()
self.scheduler.step()
if (iter_num % self.train_config["validate_each_iter"] == 0):
print(f"Epoch: {epoch+1}/{self.train_config['num_epochs']} || Iter: {iter_num}/{len(train_dataloader)} || Loss: {loss.item()}")
self.validate_model(val_dataloader, epoch * len(train_dataloader) + iter_num)
torch.save(self.state_dict(), self.train_config["save_model_path"] + f"_{epoch+1}.pth")
def test_model(self, test_dataloader: Type[DataLoader]) -> tuple[torch.tensor, torch.tensor]:
"""
Inference модели на тестовой выборке. Возвращает тензор предсказаний сег.масок и тензор истинных сег.масок
`test_dataloader` тестовая выборка
"""
# Выставляет декодировщик в режим валидации (влияет на поведение BatchNorm2d и Dropout)
self.decoder.eval()
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
dl_prediction = []
dl_target = []
with torch.no_grad():
for iter_num, (input, target) in enumerate(test_dataloader):
scores, pred = self.forward(input)
dl_prediction.append(pred.detach().cpu())
dl_target.append(target.detach().cpu())
dl_prediction = torch.cat(dl_prediction, dim=0)
dl_target = torch.cat(dl_target, dim=0)
return dl_prediction, dl_target
Вам приведены начальные значения гиперпараметров сети. Подберите гиперпараметры (если необходимо) и обучите сеть на обе функции потерь CrossEntropyLoss и FocalLoss. Добейтесь следующих результатов на тестовой выборке хотя бы для одной из них:
Mean IoU metric > 0.87Mean Recall metric > 0.96К вашему решению требуется прикрепить логи tensorboard.
from torch.optim.lr_scheduler import StepLR
train_config = {
"num_epochs": 5,
"optimizer": torch.optim.Adam,
"optimizer_params": {
"lr": 1e-3,
"weight_decay": 1e-5
},
"loss_fn": metric_class.FocalLoss, # or metric_class.FocalLoss
"scheduler": StepLR,
"scheduler_params": {
"step_size": 50,
"gamma": 0.85
},
"validate_each_iter": 10,
"writer": SummaryWriter(comment="Floss"), #Floss
"save_model_path": 'FLoss.pth',
"metric_class": metric_class
}
net = PSPNet(pretrained_model, SegmentationHead, num_classes=3, train_config=train_config).to(DEVICE)
print("#параметров в сети:", count_parameters(net))
Pretrained main module conv1 is loaded Pretrained main module bn1 is loaded Pretrained main module relu is loaded Pretrained main module maxpool is loaded Pretrained main module layer1 is loaded Pretrained main module layer2 is loaded Pretrained supp module layer3 is loaded Pretrained supp module layer4 is loaded #параметров в сети: 11759802
net.train_model(train_dataloader_memmap, val_dataloader_memmap)
Epoch: 1/5 || Iter: 0/184 || Loss: 0.809173047542572 Epoch: 1/5 || Iter: 10/184 || Loss: 0.4578273892402649 Epoch: 1/5 || Iter: 20/184 || Loss: 0.20597709715366364 Epoch: 1/5 || Iter: 30/184 || Loss: 0.1933986395597458 Epoch: 1/5 || Iter: 40/184 || Loss: 0.19391778111457825 Epoch: 1/5 || Iter: 50/184 || Loss: 0.13917164504528046 Epoch: 1/5 || Iter: 60/184 || Loss: 0.09061640501022339 Epoch: 1/5 || Iter: 70/184 || Loss: 0.10481450706720352 Epoch: 1/5 || Iter: 80/184 || Loss: 0.14566083252429962 Epoch: 1/5 || Iter: 90/184 || Loss: 0.11853332817554474 Epoch: 1/5 || Iter: 100/184 || Loss: 0.09860879927873611 Epoch: 1/5 || Iter: 110/184 || Loss: 0.12638220191001892 Epoch: 1/5 || Iter: 120/184 || Loss: 0.10887762904167175 Epoch: 1/5 || Iter: 130/184 || Loss: 0.08282797038555145 Epoch: 1/5 || Iter: 140/184 || Loss: 0.09824994951486588 Epoch: 1/5 || Iter: 150/184 || Loss: 0.09718429297208786 Epoch: 1/5 || Iter: 160/184 || Loss: 0.16996942460536957 Epoch: 1/5 || Iter: 170/184 || Loss: 0.12718182802200317 Epoch: 1/5 || Iter: 180/184 || Loss: 0.10510507225990295 Epoch: 2/5 || Iter: 0/184 || Loss: 0.08859036862850189 Epoch: 2/5 || Iter: 10/184 || Loss: 0.13131177425384521 Epoch: 2/5 || Iter: 20/184 || Loss: 0.11125364899635315 Epoch: 2/5 || Iter: 30/184 || Loss: 0.0709814578294754 Epoch: 2/5 || Iter: 40/184 || Loss: 0.17172373831272125 Epoch: 2/5 || Iter: 50/184 || Loss: 0.08666875213384628 Epoch: 2/5 || Iter: 60/184 || Loss: 0.11063261330127716 Epoch: 2/5 || Iter: 70/184 || Loss: 0.10458337515592575 Epoch: 2/5 || Iter: 80/184 || Loss: 0.1034105122089386 Epoch: 2/5 || Iter: 90/184 || Loss: 0.1510966271162033 Epoch: 2/5 || Iter: 100/184 || Loss: 0.09208666533231735 Epoch: 2/5 || Iter: 110/184 || Loss: 0.4372779130935669 Epoch: 2/5 || Iter: 120/184 || Loss: 0.13854968547821045 Epoch: 2/5 || Iter: 130/184 || Loss: 0.08269744366407394 Epoch: 2/5 || Iter: 140/184 || Loss: 0.0856635645031929 Epoch: 2/5 || Iter: 150/184 || Loss: 0.08284060657024384 Epoch: 2/5 || Iter: 160/184 || Loss: 0.08710094541311264 Epoch: 2/5 || Iter: 170/184 || Loss: 0.12266074120998383 Epoch: 2/5 || Iter: 180/184 || Loss: 0.11962353438138962 Epoch: 3/5 || Iter: 0/184 || Loss: 0.0773673728108406 Epoch: 3/5 || Iter: 10/184 || Loss: 0.07409332692623138 Epoch: 3/5 || Iter: 20/184 || Loss: 0.07445875555276871 Epoch: 3/5 || Iter: 30/184 || Loss: 0.08721039444208145 Epoch: 3/5 || Iter: 40/184 || Loss: 0.08603931963443756 Epoch: 3/5 || Iter: 50/184 || Loss: 0.1051332876086235 Epoch: 3/5 || Iter: 60/184 || Loss: 0.07602953910827637 Epoch: 3/5 || Iter: 70/184 || Loss: 0.0799580067396164 Epoch: 3/5 || Iter: 80/184 || Loss: 0.2258630096912384 Epoch: 3/5 || Iter: 90/184 || Loss: 0.06571513414382935 Epoch: 3/5 || Iter: 100/184 || Loss: 0.0787486881017685 Epoch: 3/5 || Iter: 110/184 || Loss: 0.08627845346927643 Epoch: 3/5 || Iter: 120/184 || Loss: 0.044554226100444794 Epoch: 3/5 || Iter: 130/184 || Loss: 0.07834852486848831 Epoch: 3/5 || Iter: 140/184 || Loss: 0.11499042063951492 Epoch: 3/5 || Iter: 150/184 || Loss: 0.06482337415218353 Epoch: 3/5 || Iter: 160/184 || Loss: 0.09408998489379883 Epoch: 3/5 || Iter: 170/184 || Loss: 0.09243714064359665 Epoch: 3/5 || Iter: 180/184 || Loss: 0.0742005854845047 Epoch: 4/5 || Iter: 0/184 || Loss: 0.0555943064391613 Epoch: 4/5 || Iter: 10/184 || Loss: 0.07216621190309525 Epoch: 4/5 || Iter: 20/184 || Loss: 0.07730989903211594 Epoch: 4/5 || Iter: 30/184 || Loss: 0.14474773406982422 Epoch: 4/5 || Iter: 40/184 || Loss: 0.06566669791936874 Epoch: 4/5 || Iter: 50/184 || Loss: 0.07082650810480118 Epoch: 4/5 || Iter: 60/184 || Loss: 0.07712826132774353 Epoch: 4/5 || Iter: 70/184 || Loss: 0.08188728988170624 Epoch: 4/5 || Iter: 80/184 || Loss: 0.05112256482243538 Epoch: 4/5 || Iter: 90/184 || Loss: 0.07622982561588287 Epoch: 4/5 || Iter: 100/184 || Loss: 0.05635732039809227 Epoch: 4/5 || Iter: 110/184 || Loss: 0.08246149122714996 Epoch: 4/5 || Iter: 120/184 || Loss: 0.06526417285203934 Epoch: 4/5 || Iter: 130/184 || Loss: 0.07732726633548737 Epoch: 4/5 || Iter: 140/184 || Loss: 0.05847015231847763 Epoch: 4/5 || Iter: 150/184 || Loss: 0.1348450481891632 Epoch: 4/5 || Iter: 160/184 || Loss: 0.20238502323627472 Epoch: 4/5 || Iter: 170/184 || Loss: 0.06471627205610275 Epoch: 4/5 || Iter: 180/184 || Loss: 0.06523995101451874 Epoch: 5/5 || Iter: 0/184 || Loss: 0.07207831740379333 Epoch: 5/5 || Iter: 10/184 || Loss: 0.0672682523727417 Epoch: 5/5 || Iter: 20/184 || Loss: 0.08528759330511093 Epoch: 5/5 || Iter: 30/184 || Loss: 0.11275054514408112 Epoch: 5/5 || Iter: 40/184 || Loss: 0.07174934446811676 Epoch: 5/5 || Iter: 50/184 || Loss: 0.04678089916706085 Epoch: 5/5 || Iter: 60/184 || Loss: 0.07671410590410233 Epoch: 5/5 || Iter: 70/184 || Loss: 0.09693682193756104 Epoch: 5/5 || Iter: 80/184 || Loss: 0.061378393322229385 Epoch: 5/5 || Iter: 90/184 || Loss: 0.06906907260417938 Epoch: 5/5 || Iter: 100/184 || Loss: 0.06854525208473206 Epoch: 5/5 || Iter: 110/184 || Loss: 0.050054144114255905 Epoch: 5/5 || Iter: 120/184 || Loss: 0.07530767470598221 Epoch: 5/5 || Iter: 130/184 || Loss: 0.04739246517419815 Epoch: 5/5 || Iter: 140/184 || Loss: 0.0572473481297493 Epoch: 5/5 || Iter: 150/184 || Loss: 0.10901613533496857 Epoch: 5/5 || Iter: 160/184 || Loss: 0.07046938687562943 Epoch: 5/5 || Iter: 170/184 || Loss: 0.07678557932376862 Epoch: 5/5 || Iter: 180/184 || Loss: 0.08337099850177765
Протестируйте обе модели, сравните метрики:
net.load_state_dict(torch.load('./CELoss.pth_5.pth'))
net.eval()
PSPNet(
(encoder): EncoderBlock(
(encoder_main): Sequential(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(encoder_supp): Sequential(
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
)
(decoder): DecoderBlock(
(PPM): PyramidPoolingModule(
(bins): ModuleList(
(0): Sequential(
(0): AdaptiveAvgPool2d(output_size=1)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(1): Sequential(
(0): AdaptiveAvgPool2d(output_size=2)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(2): Sequential(
(0): AdaptiveAvgPool2d(output_size=3)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(3): Sequential(
(0): AdaptiveAvgPool2d(output_size=6)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
)
(SM): SupplementaryModule(
(suppl): Sequential(
(0): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout2d(p=0.1, inplace=False)
(4): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
)
)
(UM): UpsampleModule(
(upsample): Sequential(
(0): Upsample(
(us_transform): Sequential(
(0): Conv2d(548, 274, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(274, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(1): Upsample(
(us_transform): Sequential(
(0): Conv2d(274, 137, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(137, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(2): Upsample(
(us_transform): Sequential(
(0): Conv2d(137, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
)
(head): SegmentationHead(
(segmentation_head): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout2d(p=0.0, inplace=False)
(4): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
)
)
)
dl_prediction, dl_target = net.test_model(test_dataloader)
dl_prediction, dl_target = net.test_model(test_dataloader)
print("Mean IoU metric: ", metric_class.IoUMetric(dl_prediction, dl_target))
print("Mean Recall metric: ", metric_class.RecallMetric(dl_prediction, dl_target))
Mean IoU metric: tensor(0.9417) Mean Recall metric: tensor(0.9694)
Примеры работы вами обученной сети:
img_idx = np.random.randint(0, 100)
for idx, (input, target) in enumerate(test_dataloader):
if (idx < img_idx):
continue
draw((input[0].squeeze(), target[0].squeeze()), t_dict, dl_prediction[8*idx])
plt.pause(0.1)
if (idx == img_idx+2):
break
Ответьте на вопрос: Как выбор функции потерь влияет на рассчитываемые метрики в ходе обучения?
Ваш ответ: с CrossEntropyLoss метрики в ходе обучения (IOU, Recall) лучше, чем с FocalLoss. Тут нет какого-то строгого обоснования почему, ведь вроде FocalLoss теоретически должен помогать при дисбалансе (а у нас есть дисбаланс пикселей внутри изображения). По всей видимости с нашими данными и нашей моделью этот дисбаланс не настолько существенен, чтобы FocalLoss получало преимущество.
До этого момента мы ни разу не использовали тот факт, что в нашем датасете не бывает слуаев, в которых и собака, и кошка одновременно находятся в кадре. В это же время блок SegmentationHead допускает этот случай, что дает теоретическую возможность модели ошибиться. Чтобы повысить устойчивость модели мы будем использовать две головы: голова двухклассовой сегментации, которая сегментирует животное на изображении, а вторая голова бинарной классификации будет предсказывать, что это за животное (собака или кошка). Таким образом, наша модель не имеет возможности отнести голову животного к классу "собака", а туловище к классу "кошка", что увеличивает ее устойчивость. Реализуйте двуглавый блок SegmentationClassificationHeads.
class SegmentationClassificationHeads(nn.Module):
def __init__(self, in_channels: int, num_classes: int, dropout: float = 0.1):
"""
Вычисляет score для каждого из классов
Вход: тензор (Batch_size, `in_channels`, Height, Width)
Выход: тензор (Batch_size, `num_classes`, Height, Width)
"""
super().__init__()
self.segmentation_head = nn.Sequential(
nn.Conv2d(in_channels, in_channels // 2, kernel_size=3, padding=1),
nn.BatchNorm2d(in_channels // 2),
nn.ReLU(inplace=True),
nn.Dropout2d(dropout),
nn.Conv2d(in_channels // 2, num_classes, kernel_size=1),
)
self.classification_head = nn.Sequential(
nn.Flatten(),
nn.Linear(32, num_classes - 1),
nn.Softmax(dim=1)
)
def combine_heads(self, seg_pred: torch.tensor, cls_pred: torch.tensor) -> torch.tensor:
"""
==== YOUR CODE =====
¯\_(ツ)_/¯
"""
labels_pred = cls_pred.argmax(axis=1)
mask = torch.zeros_like(seg_pred)
mask[:, labels_pred, :, :] = 1
return seg_pred * mask
def forward(self, x: torch.tensor, x_supp: torch.tensor) -> torch.tensor:
"""
Вот мы и воспользовались ранее фиктивным аргументом `x_supp`
"""
cls_pred = self.classification_head(x_supp)
seg_pred = self.segmentation_head(x)
return self.combine_heads(seg_pred, cls_pred)
Обучите двуглавую сеть и получите улучшение метрик относительно наилучшего результата предыдущего пункта:
Mean IoU metric > 0.93Mean Recall metric > 0.96К вашему решению требуется прикрепить логи tensorboard.
train_config["writer"] = SummaryWriter(comment="TwoHead_CEloss") #TwoHead_Floss
train_config["save_model_path"] = 'SegClfHead.pth'
train_config['loss_fn'] = metric_class.CrossEntropyLoss
net = PSPNet(pretrained_model, SegmentationClassificationHeads, num_classes=3, train_config=train_config).to(DEVICE)
print("#параметров в сети:", count_parameters(net))
Pretrained main module conv1 is loaded Pretrained main module bn1 is loaded Pretrained main module relu is loaded Pretrained main module maxpool is loaded Pretrained main module layer1 is loaded Pretrained main module layer2 is loaded Pretrained supp module layer3 is loaded Pretrained supp module layer4 is loaded #параметров в сети: 11759868
net.train_model(train_dataloader_memmap, val_dataloader_memmap)
Epoch: 1/5 || Iter: 0/184 || Loss: 1.006283164024353 Epoch: 1/5 || Iter: 10/184 || Loss: 0.3502597212791443 Epoch: 1/5 || Iter: 20/184 || Loss: 0.3152938485145569 Epoch: 1/5 || Iter: 30/184 || Loss: 0.229232057929039 Epoch: 1/5 || Iter: 40/184 || Loss: 0.32625967264175415 Epoch: 1/5 || Iter: 50/184 || Loss: 0.22314202785491943 Epoch: 1/5 || Iter: 60/184 || Loss: 0.29661625623703003 Epoch: 1/5 || Iter: 70/184 || Loss: 0.17813940346240997 Epoch: 1/5 || Iter: 80/184 || Loss: 0.22324040532112122 Epoch: 1/5 || Iter: 90/184 || Loss: 0.1593639850616455 Epoch: 1/5 || Iter: 100/184 || Loss: 0.2098541557788849 Epoch: 1/5 || Iter: 110/184 || Loss: 0.17502334713935852 Epoch: 1/5 || Iter: 120/184 || Loss: 0.19200189411640167 Epoch: 1/5 || Iter: 130/184 || Loss: 0.1552913933992386 Epoch: 1/5 || Iter: 140/184 || Loss: 0.21559476852416992 Epoch: 1/5 || Iter: 150/184 || Loss: 0.3583393692970276 Epoch: 1/5 || Iter: 160/184 || Loss: 0.17660365998744965 Epoch: 1/5 || Iter: 170/184 || Loss: 0.190147265791893 Epoch: 1/5 || Iter: 180/184 || Loss: 0.1797989159822464 Epoch: 2/5 || Iter: 0/184 || Loss: 0.17468833923339844 Epoch: 2/5 || Iter: 10/184 || Loss: 0.13037434220314026 Epoch: 2/5 || Iter: 20/184 || Loss: 0.194634348154068 Epoch: 2/5 || Iter: 30/184 || Loss: 0.16901983320713043 Epoch: 2/5 || Iter: 40/184 || Loss: 0.23156394064426422 Epoch: 2/5 || Iter: 50/184 || Loss: 0.17543001472949982 Epoch: 2/5 || Iter: 60/184 || Loss: 0.21650846302509308 Epoch: 2/5 || Iter: 70/184 || Loss: 0.1911405324935913 Epoch: 2/5 || Iter: 80/184 || Loss: 0.25600340962409973 Epoch: 2/5 || Iter: 90/184 || Loss: 0.1566762626171112 Epoch: 2/5 || Iter: 100/184 || Loss: 0.15329773724079132 Epoch: 2/5 || Iter: 110/184 || Loss: 0.15566708147525787 Epoch: 2/5 || Iter: 120/184 || Loss: 0.10658662021160126 Epoch: 2/5 || Iter: 130/184 || Loss: 0.14871226251125336 Epoch: 2/5 || Iter: 140/184 || Loss: 0.3607293963432312 Epoch: 2/5 || Iter: 150/184 || Loss: 0.1413830667734146 Epoch: 2/5 || Iter: 160/184 || Loss: 0.30521160364151 Epoch: 2/5 || Iter: 170/184 || Loss: 0.2131950557231903 Epoch: 2/5 || Iter: 180/184 || Loss: 0.1426122933626175 Epoch: 3/5 || Iter: 0/184 || Loss: 0.15959089994430542 Epoch: 3/5 || Iter: 10/184 || Loss: 0.17366008460521698 Epoch: 3/5 || Iter: 20/184 || Loss: 0.21151840686798096 Epoch: 3/5 || Iter: 30/184 || Loss: 0.15909597277641296 Epoch: 3/5 || Iter: 40/184 || Loss: 0.2074342519044876 Epoch: 3/5 || Iter: 50/184 || Loss: 0.13781821727752686 Epoch: 3/5 || Iter: 60/184 || Loss: 0.14632698893547058 Epoch: 3/5 || Iter: 70/184 || Loss: 0.14367517828941345 Epoch: 3/5 || Iter: 80/184 || Loss: 0.12310225516557693 Epoch: 3/5 || Iter: 90/184 || Loss: 0.15726590156555176 Epoch: 3/5 || Iter: 100/184 || Loss: 0.19093354046344757 Epoch: 3/5 || Iter: 110/184 || Loss: 0.13513793051242828 Epoch: 3/5 || Iter: 120/184 || Loss: 0.32211387157440186 Epoch: 3/5 || Iter: 130/184 || Loss: 0.11975201219320297 Epoch: 3/5 || Iter: 140/184 || Loss: 0.17612037062644958 Epoch: 3/5 || Iter: 150/184 || Loss: 0.11519519239664078 Epoch: 3/5 || Iter: 160/184 || Loss: 0.1660498082637787 Epoch: 3/5 || Iter: 170/184 || Loss: 0.15760758519172668 Epoch: 3/5 || Iter: 180/184 || Loss: 0.16032609343528748 Epoch: 4/5 || Iter: 0/184 || Loss: 0.11131583899259567 Epoch: 4/5 || Iter: 10/184 || Loss: 0.3819928765296936 Epoch: 4/5 || Iter: 20/184 || Loss: 0.0925675481557846 Epoch: 4/5 || Iter: 30/184 || Loss: 0.1356377899646759 Epoch: 4/5 || Iter: 40/184 || Loss: 0.18674714863300323 Epoch: 4/5 || Iter: 50/184 || Loss: 0.10063348710536957 Epoch: 4/5 || Iter: 60/184 || Loss: 0.13070477545261383 Epoch: 4/5 || Iter: 70/184 || Loss: 0.18780240416526794 Epoch: 4/5 || Iter: 80/184 || Loss: 0.1717657893896103 Epoch: 4/5 || Iter: 90/184 || Loss: 0.16520006954669952 Epoch: 4/5 || Iter: 100/184 || Loss: 0.14093822240829468 Epoch: 4/5 || Iter: 110/184 || Loss: 0.10212830454111099 Epoch: 4/5 || Iter: 120/184 || Loss: 0.17921021580696106 Epoch: 4/5 || Iter: 130/184 || Loss: 0.12966689467430115 Epoch: 4/5 || Iter: 140/184 || Loss: 0.10844019055366516 Epoch: 4/5 || Iter: 150/184 || Loss: 0.1326400190591812 Epoch: 4/5 || Iter: 160/184 || Loss: 0.17182667553424835 Epoch: 4/5 || Iter: 170/184 || Loss: 0.2265619933605194 Epoch: 4/5 || Iter: 180/184 || Loss: 0.11308327317237854 Epoch: 5/5 || Iter: 0/184 || Loss: 0.16694375872612 Epoch: 5/5 || Iter: 10/184 || Loss: 0.1647510677576065 Epoch: 5/5 || Iter: 20/184 || Loss: 0.13737812638282776 Epoch: 5/5 || Iter: 30/184 || Loss: 0.1677914559841156 Epoch: 5/5 || Iter: 40/184 || Loss: 0.11636580526828766 Epoch: 5/5 || Iter: 50/184 || Loss: 0.11996309459209442 Epoch: 5/5 || Iter: 60/184 || Loss: 0.11455247551202774 Epoch: 5/5 || Iter: 70/184 || Loss: 0.1332285851240158 Epoch: 5/5 || Iter: 80/184 || Loss: 0.1587146818637848 Epoch: 5/5 || Iter: 90/184 || Loss: 0.11577418446540833 Epoch: 5/5 || Iter: 100/184 || Loss: 0.19084015488624573 Epoch: 5/5 || Iter: 110/184 || Loss: 0.1464766561985016 Epoch: 5/5 || Iter: 120/184 || Loss: 0.11565650254487991 Epoch: 5/5 || Iter: 130/184 || Loss: 0.11677568405866623 Epoch: 5/5 || Iter: 140/184 || Loss: 0.13255789875984192 Epoch: 5/5 || Iter: 150/184 || Loss: 0.10625714808702469 Epoch: 5/5 || Iter: 160/184 || Loss: 0.13042517006397247 Epoch: 5/5 || Iter: 170/184 || Loss: 0.13463018834590912 Epoch: 5/5 || Iter: 180/184 || Loss: 0.1401868462562561
Тестируем модель:
net.load_state_dict(torch.load('./SegClfHead.pth_5.pth'))
net.eval()
PSPNet(
(encoder): EncoderBlock(
(encoder_main): Sequential(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer2): Sequential(
(0): Bottleneck(
(conv1): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
(encoder_supp): Sequential(
(layer3): Sequential(
(0): Bottleneck(
(conv1): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(3): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(4): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(5): Bottleneck(
(conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(512, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
(layer4): Sequential(
(0): Bottleneck(
(conv1): Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
(2): Bottleneck(
(conv1): Conv2d(2048, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(bn2): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
)
)
)
)
(decoder): DecoderBlock(
(PPM): PyramidPoolingModule(
(bins): ModuleList(
(0): Sequential(
(0): AdaptiveAvgPool2d(output_size=1)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(1): Sequential(
(0): AdaptiveAvgPool2d(output_size=2)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(2): Sequential(
(0): AdaptiveAvgPool2d(output_size=3)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
(3): Sequential(
(0): AdaptiveAvgPool2d(output_size=6)
(1): Conv2d(512, 1, kernel_size=(1, 1), stride=(1, 1))
(2): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(3): ReLU(inplace=True)
)
)
)
(SM): SupplementaryModule(
(suppl): Sequential(
(0): Conv2d(2048, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout2d(p=0.1, inplace=False)
(4): Conv2d(512, 32, kernel_size=(1, 1), stride=(1, 1))
(5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(6): ReLU(inplace=True)
)
)
(UM): UpsampleModule(
(upsample): Sequential(
(0): Upsample(
(us_transform): Sequential(
(0): Conv2d(548, 274, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(274, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(1): Upsample(
(us_transform): Sequential(
(0): Conv2d(274, 137, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(137, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
(2): Upsample(
(us_transform): Sequential(
(0): Conv2d(137, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
)
)
)
)
)
(head): SegmentationClassificationHeads(
(segmentation_head): Sequential(
(0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Dropout2d(p=0.1, inplace=False)
(4): Conv2d(128, 3, kernel_size=(1, 1), stride=(1, 1))
)
(classification_head): Sequential(
(0): Flatten(start_dim=1, end_dim=-1)
(1): Linear(in_features=32, out_features=2, bias=True)
(2): Softmax(dim=1)
)
)
)
dl_prediction, dl_target = net.test_model(test_dataloader)
print("Mean IoU metric: ", metric_class.IoUMetric(dl_prediction, dl_target))
print("Mean Recall metric: ", metric_class.RecallMetric(dl_prediction, dl_target))
Mean IoU metric: tensor(0.9452) Mean Recall metric: tensor(0.9723)
Примеры работы вами обученной двуглавой сети:
img_idx = np.random.randint(0, 100)
for idx, (input, target) in enumerate(test_dataloader):
if (idx < img_idx):
continue
draw((input[0].squeeze(), target[0].squeeze()), t_dict, dl_prediction[8*idx])
plt.pause(0.1)
if (idx == img_idx+2):
break